{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Solving a Brax environment using EvoTorch\n",
    "\n",
    "This notebook demonstrates how the Brax environment named `humanoid` can be solved using EvoTorch. The hyperparameters here are tuned for brax version 0.10.5.\n",
    "\n",
    "EvoTorch provides `VecGymNE`, a neuroevolution problem type that focuses on solving vectorized environments. If GPU is available, `VecGymNE` can utilize it to boost performance. In this notebook, we use `VecGymNE` to solve the `humanoid` task.\n",
    "\n",
    "For this notebook to work, the libraries [JAX](https://jax.readthedocs.io/en/latest/) and [Brax](https://github.com/google/brax) are required.\n",
    "For installing JAX, you might want to look at its [official installation instructions](https://github.com/google/jax#installation).\n",
    "After a successful installation of JAX, Brax can be installed via:\n",
    "\n",
    "```\n",
    "pip install brax\n",
    "```\n",
    "\n",
    "---\n",
    "\n",
    "Below, we import the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from evotorch.algorithms import PGPE\n",
    "from evotorch.neuroevolution import VecGymNE\n",
    "from evotorch.logging import StdOutLogger, PicklingLogger\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from datetime import datetime\n",
    "from glob import glob\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "We now check if CUDA is available. If it is, we prepare a configuration which will tell `VecGymNE` to use a single GPU both for the population and for the fitness evaluation operations. If CUDA is not available, we will instead turn to actor-based parallelization on the CPU to boost the performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def how_many_cuda_devices():\n",
    "    import sys\n",
    "    import subprocess as sp\n",
    "\n",
    "    instr = r\"\"\"\n",
    "import torch\n",
    "x = torch.as_tensor(1.0)\n",
    "device_count = 0\n",
    "while True:\n",
    "    try:\n",
    "        x.to(f\"cuda:{device_count}\")\n",
    "        device_count += 1\n",
    "    except Exception:\n",
    "        break\n",
    "print(device_count)\"\"\"\n",
    "\n",
    "    proc = sp.Popen([\"python\"], stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE, text=True)\n",
    "    outstr, errstr = proc.communicate(instr)\n",
    "    rcode = proc.wait()\n",
    "    if rcode == 0:\n",
    "        return int(outstr.strip())\n",
    "    else:\n",
    "        print(errstr)\n",
    "        raise RuntimeError(f\"Cannot determine number of cuda devices:\\n\\n{errstr}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_CUDA_DEVICES = how_many_cuda_devices()\n",
    "print(\"We have\", NUM_CUDA_DEVICES, \"CUDA devices\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    assert NUM_CUDA_DEVICES >= 1\n",
    "    # CUDA is available. Here, we prepare GPU-specific settings.\n",
    "\n",
    "    if NUM_CUDA_DEVICES == 1:\n",
    "        # We make only one GPU visible.\n",
    "        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "        # os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \".5\" # Tell JAX to pre-allocate half of a GPU\n",
    "        os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"  # Tell JAX to allocate on demand\n",
    "        \n",
    "        # This is the device on which the population will be stored\n",
    "        device = \"cuda:0\"\n",
    "        \n",
    "        # We do not want multi-actor parallelization when we have only 1 GPU.\n",
    "        num_actors = 0\n",
    "\n",
    "        # In the case of 1 CUDA device, there will be no distributed training\n",
    "        num_gpus_per_actor = None\n",
    "        distributed_algorithm = False\n",
    "    else:\n",
    "        # In the case of more than one CUDA devices, we enable distributed training\n",
    "\n",
    "        # os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = str((1 / NUM_CUDA_DEVICES) / 2)\n",
    "        os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"  # Tell JAX to allocate on demand\n",
    "\n",
    "        device = \"cpu\"  # Main device of the population is cpu\n",
    "        num_actors = NUM_CUDA_DEVICES  # Allocate an actor per GPU\n",
    "        num_gpus_per_actor = 1  # Each actor gets assigned a GPU\n",
    "        distributed_algorithm = True  # PGPE is to work on distributed mode\n",
    "else:\n",
    "    # Since CUDA is not available, the device of the population will be cpu.\n",
    "    device = \"cpu\"\n",
    "\n",
    "    # No actor per GPU, since GPU is not available\n",
    "    num_gpus_per_actor = None\n",
    "    distributed_algorithm = False\n",
    "\n",
    "    #num_actors = \"max\"  # Use all the CPUs to speed-up the evaluations.\n",
    "    num_actors = 1\n",
    "    \n",
    "    # Because we are already using all the CPUs for actor-based parallelization,\n",
    "    # we tell XLA not to use multiple threads for its operations.\n",
    "    # (Following the suggestions at https://github.com/google/jax/issues/743)\n",
    "    os.environ[\"XLA_FLAGS\"] = \"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1\"\n",
    "\n",
    "    # We also tell OpenBLAS and MKL to use only 1 thread for their operations.\n",
    "    os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n",
    "    os.environ[\"MKL_NUM_THREADS\"] = \"1\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "We now define our policy. The policy can be expressed as a string, or as an instance or as a subclass of `torch.nn.Module`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- A simple linear policy ---\n",
    "# policy = \"Linear(obs_length, act_length)\"\n",
    "\n",
    "\n",
    "# --- A feed-forward network ---\n",
    "policy = \"Linear(obs_length, 64) >> Tanh() >> Linear(64, act_length)\"\n",
    "\n",
    "\n",
    "# --- A feed-forward network with layer normalization ---\n",
    "# policy = (\n",
    "#     \"\"\"\n",
    "#     Linear(obs_length, 64)\n",
    "#     >> Tanh()\n",
    "#     >> LayerNorm(64, elementwise_affine=False)\n",
    "#     >> Linear(64, act_length)\n",
    "#     \"\"\"\n",
    "# )\n",
    "\n",
    "# --- A recurrent network with layer normalization ---\n",
    "# Note: in addition to RNN, LSTM is also supported\n",
    "#\n",
    "# policy = (\n",
    "#     \"\"\"\n",
    "#     RNN(obs_length, 64)\n",
    "#     >> LayerNorm(64, elementwise_affine=False)\n",
    "#     >> Linear(64, act_length)\n",
    "#     \"\"\"\n",
    "# )\n",
    "\n",
    "\n",
    "# --- A manual feed-forward network ---\n",
    "# class MyManualNetwork(nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super().__init__()\n",
    "#         ...\n",
    "#\n",
    "#    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "#        ...\n",
    "#\n",
    "# policy = MyManualNetwork\n",
    "\n",
    "\n",
    "# --- A manual recurrent network ---\n",
    "# class MyManualRecurrentNetwork(nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super().__init__()\n",
    "#         ...\n",
    "#\n",
    "#     def forward(self, x: torch.Tensor, hidden_state = None) -> tuple:\n",
    "#         ...\n",
    "#         output_tensor = ...\n",
    "#         new_hidden_state = ...  # hidden state could be a tensor, or a tuple or dict of tensors\n",
    "#         return output_tensor, new_hidden_state\n",
    "#\n",
    "# policy = MyManualRecurrentNetwork"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "Below, we instantiate our `VecGymNE` problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_NAME = \"brax::humanoid\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "**Note.**\n",
    "At the time of writing this (27 May 2024), the [arXiv paper of EvoTorch](https://arxiv.org/abs/2302.12600v3) reports results based on the old implementations of the brax tasks (which were the default until brax v0.1.2). In brax version v0.9.0, these old task implementations moved into the namespace `brax.v1`. If you wish to reproduce the results reported in the arXiv paper of EvoTorch, you might want to specify the environment name as `\"brax::old::humanoid\"` (where the substring `\"old::\"` causes `VecGymNE` to instantiate the environment using the namespace `brax.v1`), so that you will observe scores and execution times compatible with the ones reported in that arXiv paper. Please also see the mentioned arXiv paper for the hyperparameters used for the old brax environments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "problem = VecGymNE(\n",
    "    env=TASK_NAME,\n",
    "    network=policy,\n",
    "    #\n",
    "    # Collect observation stats, and use those stats to normalize incoming observations\n",
    "    observation_normalization=True,\n",
    "    #\n",
    "    # In the case of the \"humanoid\" task, the agent receives an \"alive bonus\" of 5.0 for each\n",
    "    # non-terminal state it observes. In this example, we cancel out this fixed amount of\n",
    "    # alive bonus using the keyword argument `decrease_rewards_by`.\n",
    "    # The amount of alive bonus changes from task to task (some of them don't have this bonus\n",
    "    # at all).\n",
    "    decrease_rewards_by=5.0,\n",
    "    #\n",
    "    # As an alternative to giving a fixed amount of alive bonus, we now enable a scheduled\n",
    "    # alive bonus.\n",
    "    # From timestep 0 to 400, the agents will receive no alive bonus.\n",
    "    # From timestep 400 to 700, the agents will receive partial alive bonus.\n",
    "    # Beginning with timestep 700, the agents will receive full (10.0) alive bonus.\n",
    "    alive_bonus_schedule=(400, 700, 10.0),\n",
    "    device=device,\n",
    "    num_actors=num_actors,\n",
    "    num_gpus_per_actor=num_gpus_per_actor,\n",
    ")\n",
    "\n",
    "problem, problem.solution_length"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "Initialize a PGPE to work on the problem.\n",
    "\n",
    "Note: If you receive memory allocation error from the GPU driver, you might want to try again with:\n",
    "- a decreased `popsize`\n",
    "- a policy with decreased hidden size and/or number of layers (in case the policy is a neural network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "RADIUS = 2.25\n",
    "MAX_SPEED = RADIUS / 15\n",
    "CENTER_LR = MAX_SPEED * 0.75\n",
    "\n",
    "POPSIZE = 4000\n",
    "NUM_GENERATIONS = 1000\n",
    "SAVE_INTERVAL = 20\n",
    "\n",
    "# Instantiate a PGPE using the hyperparameters prepared above\n",
    "searcher = PGPE(\n",
    "    problem,\n",
    "    popsize=POPSIZE,\n",
    "    num_interactions=(POPSIZE * 1000 * 0.75),\n",
    "    radius_init=RADIUS,\n",
    "    center_learning_rate=CENTER_LR,\n",
    "    optimizer=\"clipup\",\n",
    "    optimizer_config={\"max_speed\": MAX_SPEED},\n",
    "    stdev_learning_rate=0.1,\n",
    "    distributed=distributed_algorithm,\n",
    ")\n",
    "\n",
    "searcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_name_for_saving = TASK_NAME.split(\"::\")[-1]\n",
    "now_as_str = datetime.now().strftime(\"%Y-%m-%d-%H.%M.%S\")\n",
    "OUTPUT_DIR = f\"{task_name_for_saving}_{now_as_str}_{os.getpid()}\"\n",
    "\n",
    "print(\"PicklingLogger will save into\", OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "We register two loggers for our PGPE instance.\n",
    "\n",
    "- **StdOutLogger:** A logger which will print out the status of the optimization.\n",
    "- **PicklingLogger:** A logger which will periodically save the latest result into a pickle file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = StdOutLogger(searcher)\n",
    "pickler = PicklingLogger(searcher, interval=SAVE_INTERVAL, directory=OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "We are now ready to start the evolutionary search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "for generation in range(1, 1 + NUM_GENERATIONS):\n",
    "    t_before_step = datetime.now()\n",
    "    searcher.step()\n",
    "    t_after_step = datetime.now()\n",
    "    print(\"Elapsed:\", (t_after_step - t_before_step).total_seconds())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The run is finished.\")\n",
    "print(\"The pickle file that contains the latest result is:\")\n",
    "print(pickler.last_file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "See the notebook [Brax_Experiments_Visualization.ipynb](Brax_Experiments_Visualization.ipynb) for visualizing the pickle files generated by this notebook."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
